{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import email_read_util\n",
    "from datasketch import MinHash, MinHashLSH"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download 2007 TREC Public Spam Corpus\n",
    "1. Read the \"Agreement for use\"\n",
    "   https://plg.uwaterloo.ca/~gvcormac/treccorpus07/\n",
    "\n",
    "2. Download 255 MB Corpus (trec07p.tgz) and untar into the 'chapter1/datasets' directory\n",
    "\n",
    "3. Check that the below paths for 'DATA_DIR' and 'LABELS_FILE' exist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = 'datasets/trec07p/data/'\n",
    "LABELS_FILE = 'datasets/trec07p/full/index'\n",
    "TRAINING_SET_RATIO = 0.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the labels\n",
    "with open(LABELS_FILE) as f:\n",
    "    for line in f:\n",
    "        line = line.strip()\n",
    "        label, key = line.split()\n",
    "        labels[key.split('/')[-1]] = 1 if label.lower() == 'ham' else 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split corpus into train and test sets\n",
    "filelist = os.listdir(DATA_DIR)\n",
    "X_train = filelist[:int(len(filelist)*TRAINING_SET_RATIO)]\n",
    "X_test = filelist[int(len(filelist)*TRAINING_SET_RATIO):]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract only spam files for inserting into the LSH matcher\n",
    "spam_files = [x for x in X_train if labels[x] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize MinHashLSH matcher with a Jaccard \n",
    "# threshold of 0.5 and 128 MinHash permutation functions\n",
    "lsh = MinHashLSH(threshold=0.5, num_perm=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Populate the LSH matcher with training spam MinHashes\n",
    "for idx, f in enumerate(spam_files):\n",
    "    minhash = MinHash(num_perm=128)\n",
    "    stems = email_read_util.load(os.path.join(DATA_DIR, f))\n",
    "    if len(stems) < 2: continue\n",
    "    for s in stems:\n",
    "        minhash.update(s.encode('utf-8'))\n",
    "    lsh.insert(f, minhash)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lsh_predict_label(stems):\n",
    "    '''\n",
    "    Queries the LSH matcher and returns:\n",
    "        0 if predicted spam\n",
    "        1 if predicted ham\n",
    "       -1 if parsing error\n",
    "    '''\n",
    "    minhash = MinHash(num_perm=128)\n",
    "    if len(stems) < 2:\n",
    "        return -1\n",
    "    for s in stems:\n",
    "        minhash.update(s.encode('utf-8'))\n",
    "    matches = lsh.query(minhash)\n",
    "    if matches:\n",
    "        return 0\n",
    "    else:\n",
    "        return 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "fp = 0\n",
    "tp = 0\n",
    "fn = 0\n",
    "tn = 0\n",
    "\n",
    "for filename in X_test:\n",
    "    path = os.path.join(DATA_DIR, filename)\n",
    "    if filename in labels:\n",
    "        label = labels[filename]\n",
    "        stems = email_read_util.load(path)\n",
    "        if not stems:\n",
    "            continue\n",
    "        pred = lsh_predict_label(stems)\n",
    "        if pred == -1:\n",
    "            continue\n",
    "        elif pred == 0:\n",
    "            if label == 1:\n",
    "                fp = fp + 1\n",
    "            else:\n",
    "                tp = tp + 1\n",
    "        elif pred == 1:\n",
    "            if label == 1:\n",
    "                tn = tn + 1\n",
    "            else:\n",
    "                fn = fn + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table><tr><td>7350</td><td>136</td></tr><tr><td>2241</td><td>11038</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import HTML, display\n",
    "conf_matrix = [[tn, fp],\n",
    "               [fn, tp]]\n",
    "display(HTML('<table><tr>{}</tr></table>'.format(\n",
    "    '</tr><tr>'.join('<td>{}</td>'.format(\n",
    "        '</td><td>'.join(str(_) for _ in row)) \n",
    "                     for row in conf_matrix))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table><tr><td>35.4%</td><td>0.7%</td></tr><tr><td>10.8%</td><td>53.2%</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "count = tn + tp + fn + fp\n",
    "percent_matrix = [[\"{:.1%}\".format(tn/count), \"{:.1%}\".format(fp/count)],\n",
    "                  [\"{:.1%}\".format(fn/count), \"{:.1%}\".format(tp/count)]]\n",
    "display(HTML('<table><tr>{}</tr></table>'.format(\n",
    "    '</tr><tr>'.join('<td>{}</td>'.format(\n",
    "        '</td><td>'.join(str(_) for _ in row)) \n",
    "                     for row in percent_matrix))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification accuracy: 88.6%\n"
     ]
    }
   ],
   "source": [
    "print(\"Classification accuracy: {}\".format(\"{:.1%}\".format((tp+tn)/count)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
